import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse
from sympy.functions.elementary.miscellaneous import cbrt


def g_CauchyK_num(S):
    """Return a symbolic, imaginary-shifted Stieltjes-type transform of ``S``.

    Builds, as a sympy expression in the symbol ``z``::

        g(z) = 1/(2N) * sum_j [ 1/(z + s_j - i*eta) + 1/(z - s_j - i*eta) ]

    with ``N = len(S)`` and ``eta = sqrt(1/(2N))``.  Each value therefore
    contributes both ``+s_j`` and ``-s_j`` (a symmetrized spectrum).  The
    name suggests a Cauchy-kernel smoothing of the singular-value density
    -- NOTE(review): confirm the intended interpretation.

    Parameters
    ----------
    S : sequence of float
        Values to build the transform from (``main`` passes the singular
        values of the observation matrix).

    Returns
    -------
    sympy expression in ``z``.

    Raises
    ------
    ZeroDivisionError
        If ``S`` is empty.
    """
    z = Symbol('z')
    n = len(S)
    # Imaginary offset shared by every pole; it depends only on n.
    eta = np.sqrt(1/(2*n))

    total = 0
    for s in S:
        total += 1/(z + s - I*eta)   # pole at -s
        total += 1/(z - s - I*eta)   # pole at +s
    return total/(2*n)

def Estimator(S_s, gS, SNR, a, c):
    """Compute cleaned (RIE) singular values for Y, one per entry of ``S_s``.

    For each empirical singular value ``s_i``, the transform ``gS`` is
    evaluated at ``zz = s_i - i*eta`` with ``eta = sqrt(1/(2N))`` and
    ``N = len(S_s)``.  A closed-form quantity ``q4`` is then formed; its
    structure (a cube root over a square-root discriminant, shifted by
    ``-c/2``) looks like a Cardano-style root of a cubic, presumably
    machine-generated.  NOTE(review): confirm the underlying equation.
    The returned value is::

        a * Im(q4) / (sqrt(SNR) * Im(gS(zz)))

    Parameters
    ----------
    S_s : sequence of float
        Empirical singular values of the observation matrix.
    gS : sympy expression
        Expression in the symbol ``z``.  ``main`` passes the result of
        ``g_CauchyK_num(S_s)``.
    SNR : float
        Signal-to-noise ratio of the observation model.
    a : float
        Model parameter; ``main`` passes ``N/M``.
    c : float
        Model parameter; ``main`` passes the diagonal shift added to X.

    Returns
    -------
    numpy.ndarray of float, length ``len(S_s)``.
    """
    N = len(S_s)
    output = np.zeros(N)
    z = Symbol('z')

    # Loop-invariant scalars (identical values on every iteration).
    eta = np.sqrt(1/(2*N))
    snr_scale = np.sqrt(SNR)

    for i, s in enumerate(S_s):
        # Evaluation point slightly below the real axis.
        zz = s - I*eta
        gz = gS.subs(z, zz).evalf()

        # Named pieces of the closed-form expression.  Each one is a complete
        # parenthesized sub-term of the original formula, built with the same
        # operand order, so the resulting sympy tree is unchanged.
        numer = a*(-2 + a*(-2 + c**2))*zz - 2*(gz**2)*zz + 2*gz*(1 + a*(-1 + zz**2))
        poly = a*(2 - a*(-2 + c**2))*zz + 2*(gz**2)*zz - 2*gz*(1 + a*(-1 + zz**2))
        radicand = (27*(a**2)*(c**2)*zz*((gz - a*zz)**2)*((-1 + a + gz*zz)**2)
                    + poly**3)/((a**6)*(zz**3))
        root_disc = sqrt(radicand)

        small_arg = 9*c*(a*zz - gz)*(gz*zz + a - 1)/((a**2)*zz) + sqrt(3)*root_disc
        big_arg = 27*c*(a*zz - gz)*(gz*zz + a - 1)/((a**2)*zz) + 3*sqrt(3)*root_disc

        q4 = -3*c + (3**(2/3))*numer/((a**2)*zz*cbrt(small_arg)) + cbrt(big_arg)
        q4 = (q4/6).evalf()

        output[i] = (a*im(q4)/(snr_scale*im(gz))).evalf()

    return output


def _shifted_wigner(n, c):
    """Return a symmetric n x n Gaussian matrix (GOE scaling) plus ``c*I``.

    Before the shift, off-diagonal entries have variance 1/n and diagonal
    entries have variance 2/n.  Draws the (n, n) normal block first and the
    diagonal second, the same order as the original inline code.
    """
    # k=1 keeps only the strict upper triangle. The previous np.triu (k=0)
    # also kept the diagonal, which, together with the extra sqrt(2)-scaled
    # diagonal below, gave diagonal variance 6/n instead of 2/n.
    upper = np.triu(np.random.normal(0, 1, (n, n)), k=1)
    diag = np.random.normal(loc=0, scale=np.sqrt(2), size=(n))
    X = (upper + upper.T + np.diag(diag))/np.sqrt(n)
    return X + c*np.eye(n)


def _reconstruct(U, values, Vh):
    """Return ``U @ diag(values) @ Vh`` for a thin SVD factorization."""
    # Broadcasting scales column k of U by values[k], so the (possibly
    # rectangular) diagonal matrix is never materialized.
    return (U * values) @ Vh


def _relative_error(Y, Y_hat):
    """Return ||Y - Y_hat||_F^2 / ||Y||_F^2."""
    return LA.norm(Y - Y_hat)**2 / LA.norm(Y)**2


def main():
    """Run the Monte-Carlo comparison of the oracle and RIE estimators for Y.

    Model (per trial): ``S = sqrt(SNR) * X @ Y + W``, where
      * Y and W are N x M with i.i.d. N(0, 1/N) entries,
      * X is a symmetric Gaussian N x N matrix shifted by ``c*I``.

    For each of ``Ex`` trials, the relative squared Frobenius error of two
    estimates of Y is recorded.  Both estimates keep the singular vectors
    of S:
      * oracle: singular values ``u_k^T Y v_k``;
      * RIE:    singular values from ``Estimator``.

    Command line
    ------------
    -M      number of columns of Y and W (required, positive)
    --seed  optional seed for NumPy's global RNG (default: unseeded)

    Side effects
    ------------
    Writes ``Y-Gaussian_N=<N>_M=<M>_SNR=<SNR>_Oracle.npy`` and
    ``..._RIE.npy`` to the current directory.
    """
    p = argparse.ArgumentParser()
    # Previously optional, so omitting it crashed later on N/None.
    p.add_argument('-M', type=int, required=True)
    p.add_argument('--seed', type=int, default=None)
    args = p.parse_args()

    N = 2000
    M = args.M
    if M <= 0:
        p.error('-M must be a positive integer')
    a = N/M
    SNR = 5
    c = 3       # diagonal shift of X (also passed to Estimator)
    Ex = 10     # number of Monte-Carlo trials

    if args.seed is not None:
        np.random.seed(args.seed)

    E_oracle = np.zeros(Ex)
    E_RIE = np.zeros(Ex)

    for i in range(Ex):
        # Random draws in the same order as before: Y, then X, then W.
        Y = np.random.randn(N, M)/np.sqrt(N)
        X = _shifted_wigner(N, c)
        W = np.random.randn(N, M)/np.sqrt(N)

        ### Observation
        S = np.sqrt(SNR) * X @ Y + W

        ### Thin SVD: U_s is N x K, Vh_s is K x M, with K = min(N, M).
        # This works for M > N, where the full-SVD zero padding failed.
        U_s, S_s, Vh_s = LA.svd(S, full_matrices=False)

        ### Oracle estimator for Y: e_k = u_k^T Y v_k for every k at once.
        e_hat_oracle = np.sum(U_s * (Y @ Vh_s.T), axis=0)
        E_oracle[i] = _relative_error(Y, _reconstruct(U_s, e_hat_oracle, Vh_s))

        #### RIE for Y
        gS = g_CauchyK_num(S_s)
        e_hat = Estimator(S_s, gS, SNR, a, c)
        E_RIE[i] = _relative_error(Y, _reconstruct(U_s, e_hat, Vh_s))

    prefix = f'Y-Gaussian_N={N}_M={M}_SNR={SNR}'
    np.save(prefix + '_Oracle.npy', E_oracle)
    np.save(prefix + '_RIE.npy', E_RIE)

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
    
